{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.7 门控循环单元GRU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import sys\n",
    "import d2lzh_pytorch as d2l\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "will use cuda\n"
     ]
    }
   ],
   "source": [
    "(corpus_indices, char_to_idx, idx_to_char, vocab_size) = \\\n",
    "    d2l.load_data_jay_lyrics()\n",
    "num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size\n",
    "print('will use', device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_params():\n",
    "    def _one(shape):\n",
    "        ts = torch.tensor(np.random.normal(0, 0.01, size=shape), \\\n",
    "                          device=device, dtype=torch.float32)\n",
    "        return torch.nn.Parameter(ts, requires_grad=True)\n",
    "    \n",
    "    def _three():\n",
    "        return (_one((num_inputs, num_hiddens)),\n",
    "               _one((num_hiddens, num_hiddens)),\n",
    "               torch.nn.Parameter(torch.zeros(num_hiddens, device=device,\n",
    "                                dtype=torch.float32), requires_grad=True))\n",
    "    \n",
    "    W_xz, W_hz, b_z = _three()    # 更新门参数\n",
    "    W_xr, W_hr, b_r = _three()    # 重置门参数\n",
    "    W_xh, W_hh, b_h = _three()    # 候选隐藏状态参数\n",
    "    \n",
    "    # 输出层参数\n",
    "    W_hq = _one((num_hiddens, num_outputs))\n",
    "    b_q = torch.nn.Parameter(torch.zeros(num_outputs,\n",
    "                device=device, dtype=torch.float32), requires_grad=True)\n",
    "    return nn.ParameterList([W_xz, W_hz, b_z, W_xr, W_hr, b_r, \n",
    "                            W_xh, W_hh, b_h, W_hq, b_q])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_gru_state(batch_size, num_hiddens, device):\n",
    "    return (torch.zeros((batch_size, num_hiddens), device=device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gru(inputs, state, params):\n",
    "    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params\n",
    "    H, = state\n",
    "    outputs = []\n",
    "    for X in inputs:\n",
    "        Z = torch.sigmoid(torch.matmul(X, W_xz) + torch.matmul(H, W_hz) + b_z)\n",
    "        R = torch.sigmoid(torch.matmul(X, W_xr) + torch.matmul(H, W_hr) + b_r)\n",
    "        H_tilda = torch.tanh(torch.matmul(X, W_xh) + R *\n",
    "                            torch.matmul(H, W_hh) + b_h)\n",
    "        H = Z * H + (1 - Z) * H_tilda\n",
    "        Y = torch.matmul(H, W_hq) + b_q\n",
    "        outputs.append(Y)\n",
    "    return outputs, (H,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "too many values to unpack (expected 1)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-10-b85c0a276613>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      4\u001b[0m             \u001b[0mvocab_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcorpus_indices\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midx_to_char\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchar_to_idx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m             \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_epochs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_steps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclipping_theta\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m             pred_period, pred_len, prefixes)\n\u001b[0m",
      "\u001b[0;32m/mnt/data/git-repository/Learning-ML/dive_into_ml/d2lzh_pytorch.py\u001b[0m in \u001b[0;36mtrain_and_predict_rnn\u001b[0;34m(rnn, get_params, init_rnn_state, num_hiddens, vocab_size, device, corpus_indices, idx_to_char, char_to_idx, is_random_iter, num_epochs, num_steps, lr, clipping_theta, batch_size, pred_period, pred_len, prefixes)\u001b[0m\n\u001b[1;32m    452\u001b[0m             \u001b[0minputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mto_onehot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvocab_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    453\u001b[0m             \u001b[0;31m# outputs有num_steps个形状为(batch_size, vocab_size)的矩阵\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 454\u001b[0;31m             \u001b[0;34m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrnn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    455\u001b[0m             \u001b[0;31m# 拼接之后形状为(num_steps * batch_size, vocab_size)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    456\u001b[0m             \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-9-cbd14b29b934>\u001b[0m in \u001b[0;36mgru\u001b[0;34m(inputs, state, params)\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mgru\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m     \u001b[0mW_xz\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mW_hz\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_z\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mW_xr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mW_hr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_r\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mW_xh\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mW_hh\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_h\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mW_hq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_q\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m     \u001b[0mH\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      4\u001b[0m     \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mX\u001b[0m \u001b[0;32min\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: too many values to unpack (expected 1)"
     ]
    }
   ],
   "source": [
    "num_epochs, num_steps, batch_size, lr, clipping_theta = 160, 35, 16, 1e2, 1e-2\n",
    "pred_period, pred_len, prefixes = 40, 50, ['分开', '不分开']\n",
    "d2l.train_and_predict_rnn(gru, get_params, init_gru_state, num_hiddens,\n",
    "            vocab_size, device, corpus_indices, idx_to_char, char_to_idx,\n",
    "            False, num_epochs, num_steps, lr, clipping_theta, batch_size,\n",
    "            pred_period, pred_len, prefixes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
